# ------------------------------------------------------------------------------
# Functions for standardizing graphic aesthetics throughout the repo
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To use: in a script, set a <- modules::use(here("lib", "aesthetics.R"));
#         functions may then be called as, e.g., a$get_font(...).
# ------------------------------------------------------------------------------

# Load libraries ---------------------------------------------------------------
import(showtext) # use Optima font in plots
import(ggplot2) # necessary to run tile_plot function
import(ggthemes) # colorblind palette for tile_plot function
import(testit)
import(sysfonts) # add a font
import(rprojroot) # get root for loading font
import(dplyr)
import(glue)
import(grDevices) # colorRamp

# Palette ----------------------------------------------------------------------
# Five-colour discrete palette for categorical groupings, and a lighter
# companion whose i-th entry pairs with the i-th entry of disc_palette.
disc_palette <- c(
  "#475C7A",
  "#FCBB6D",
  "#AB6C82",
  "#685D79",
  "#D49478"
)
disc_palette_light <- c(
  "#5B759A",
  "#FDD19B",
  "#BC8A9C",
  "#817595",
  "#E2B5A2"
)

# The "ordered" palettes use the very same colours in a fixed sequence:
# entries 2, 5, 3, 4, 1 of the corresponding discrete palette.
ordered_palette <- disc_palette[c(2, 5, 3, 4, 1)]
ordered_palette_light <- disc_palette_light[c(2, 5, 3, 4, 1)]

# Continuous ramps: each is a function of n returning n colours interpolated
# along the matching ordered palette.
cont_ramp <- grDevices::colorRampPalette(ordered_palette)
cont_ramp_light <- grDevices::colorRampPalette(ordered_palette_light)

# Function definitions ---------------------------------------------------------
get_font <- function(fontname, fontpath){

  # Make a font file usable in plots under a chosen family name.
  #
  # INPUT:
  #   fontname - family name the font is registered under (use it in
  #              element_text(family = ...) afterwards)
  #   fontpath - path to the .ttf file holding the regular face
  # OUTPUT: the result of showtext_auto(); the function is called for its
  #   side effects (font registered via sysfonts, showtext rendering enabled).

  sysfonts::font_add(regular = fontpath, family = fontname)
  showtext::showtext_auto(enable = TRUE)
}

tile_plot <- function(mean_outcomes, palette = disc_palette, large_font = FALSE, ...){

  # Line-and-ribbon plot of grouped estimates across x_var.
  #
  # INPUT:
  #   mean_outcomes - data frame with (at least) these columns; others are ignored:
  #     x_var            numeric x position (see x-axis labels note below)
  #     beta             point estimate, plotted on the y axis
  #     beta_lo, beta_hi lower/upper bounds, drawn as a shaded ribbon
  #     grouping         one line/ribbon/colour per distinct value
  #   palette    - colours for the manual colour and fill scales (one per group)
  #   large_font - if TRUE use base text size 60, otherwise 40
  #   ...        - accepted for backward compatibility; currently unused
  # OUTPUT: an unlabeled ggplot object. A legend is shown at the bottom only
  #   when there is more than one group.
  # SIDE EFFECT: registers the Optima font from <repo root>/lib/optima.ttf.
  # Errors (via testit::assert) if a required column is missing.

  # Register Optima; the repo root is located by its .git/index file.
  root <- find_root(has_file(".git/index"))
  get_font("Optima", file.path(root, "lib", "optima.ttf"))

  requirements <- c("beta", "x_var", "beta_lo", "beta_hi", "grouping")
  missing_cols <- setdiff(requirements, names(mean_outcomes))
  assert(
    paste0("Error: Mean outcomes table is missing required column(s): ",
           paste(missing_cols, collapse = ", ")),
    length(missing_cols) == 0
  )

  num_grps <- n_distinct(mean_outcomes$grouping)
  legend_pos <- if (num_grps > 1) "bottom" else "none"

  # Axis labels are x * 100 / (number of distinct x values). This treats x_var
  # as bin indices 1..n shown as percentiles. NOTE(review): confirm callers
  # always pass such indices.
  x_vals <- unique(mean_outcomes$x_var)
  x_labs <- x_vals * (100 / length(x_vals))
  font_size <- if (isTRUE(large_font)) 60 else 40

  gg <- ggplot(mean_outcomes,
               aes(x = x_var, y = beta,
                   ymin = beta_lo,
                   ymax = beta_hi,
                   group = as.factor(grouping),
                   color = as.factor(grouping),
                   fill = as.factor(grouping))) +
    geom_ribbon(color = NA, alpha = 0.4) +
    geom_line(linetype = "dashed") +
    geom_point() +
    theme_bw() +
    scale_x_continuous(breaks = x_vals, labels = x_labs) +
    theme(legend.position = legend_pos,
          text = element_text(family = "Optima", size = font_size)) +
    # Alternative: scale_color_colorblind() / scale_fill_colorblind() (ggthemes)
    scale_color_manual(values = palette) +
    scale_fill_manual(values = palette)

  gg
}

plot_contr_exp_curve <- function(
  df, quantile_rates, yhat_var, cohort_spl, outcome_var_quantiles,
  outcome_var_curve, rm_or_add, smooth = TRUE, nocp, annotate = FALSE
){
  # Plot outcome rate against test rate. A curve is fitted from `df`, and
  # labelled quantile points with ranges are drawn from `quantile_rates`.
  #
  # INPUT:
  #   df - data frame with test_rate, outcome_rate, outcome_rate_lo,
  #        outcome_rate_hi. Drawn as a GAM smooth (smooth = TRUE) or as points
  #        with faint point-ranges (smooth = FALSE).
  #   quantile_rates - data frame with test_rate, outcome_rate,
  #        outcome_rate_lo, outcome_rate_hi and name. Drawn as points with
  #        line ranges, each labelled by `name`.
  #   outcome_var_quantiles - sets the y-axis label: contains "stent" ->
  #        "Overall Intervention Rate", contains "mace" -> "Overall
  #        MACE-Untested (30 Day) Rate", otherwise the value itself.
  #   smooth - see `df`.
  #   yhat_var, outcome_var_curve, cohort_spl, rm_or_add, nocp - used only to
  #        build the title, subtitle and caption when annotate = TRUE.
  #   annotate - if TRUE, add title/subtitle/caption. Defaults to FALSE, which
  #        gives the original unannotated plot.
  # OUTPUT: a ggplot object.
  # SIDE EFFECT: registers the Optima font from <repo root>/lib/optima.ttf.

  # Register Optima; the repo root is located by its .git/index file.
  root <- find_root(has_file(".git/index"))
  get_font("Optima", file.path(root, "lib", "optima.ttf"))

  outcome_lab <- case_when(
    grepl("stent", outcome_var_quantiles) ~ "Overall Intervention Rate",
    grepl("mace", outcome_var_quantiles) ~ "Overall MACE-Untested (30 Day) Rate",
    TRUE ~ outcome_var_quantiles
  )

  gg <- ggplot(
    df, aes(
      x = test_rate, y = outcome_rate, ymin = outcome_rate_lo,
      ymax = outcome_rate_hi
    )
  ) +
    labs(x = "Test Rate", y = outcome_lab) +
    geom_point(
      data = quantile_rates, aes(
        x = test_rate, y = outcome_rate
      ), color = disc_palette[1]
    ) +
    geom_linerange(
      data = quantile_rates, aes(
        x = test_rate, ymin = outcome_rate_lo, ymax = outcome_rate_hi
      ), color = disc_palette[1]
    ) +
    # Justification is constant, so it is a layer parameter rather than a
    # mapped aesthetic.
    geom_text(data = quantile_rates, aes(label = name),
              hjust = 0, vjust = 0, nudge_x = 0.0005, size = 12) +
    theme_bw() +
    theme(text = element_text(family = "Optima", size = 40))

  if (isTRUE(annotate)) {
    subtitle <- glue("Quant Outcome Var: {outcome_var_quantiles}; Curve Outcome Var: {outcome_var_curve}")
    subtitle <- glue("{subtitle}\n Y-Hat Var: {yhat_var}")
    caption <- glue("{cohort_spl} cohort, excluding patients with chronic illness, 80+ etc")
    if (nocp) {
      caption <- glue("{caption}; Excluding patients with chief complaint chest or esophagus pain")
    }
    gg <- gg + labs(
      title = glue("{outcome_lab} of Algorithm-{rm_or_add} Tests"),
      subtitle = subtitle,
      caption = caption
    )
  }

  if (smooth) {
    gg <- gg +
      geom_smooth(method = "gam", formula = y ~ s(x, bs = "cs"), color = disc_palette[1])
  } else {
    gg <- gg + geom_point(color = disc_palette[1]) +
      geom_pointrange(alpha = 0.2, color = disc_palette[1])
  }

  gg
}
